import pickle
import matplotlib as plt
import os
import numpy as np
import matplotlib.pyplot as plt
import json
from itertools import combinations

from transformers import AutoTokenizer, AutoModel
import torch
import os
import sys

from transformers import pipeline

# JSON file located at <final_results_folder>/<corpus>/<int_types[0]>/; its
# keys are the interpolation ids that the __main__ section iterates over.
interpolation_dict_file_name = "interpolation_dict.json"
# Sub-folder of <final_results_folder>/<corpus>/ that receives one directory
# per interpolation id.
other_generations_folder_name = "other_generations"
# Pickle written inside each interpolation directory: a list holding one list
# of generated samples per entry of n_samples_list.
output_file_name = "llama_prompt_generations.pkl"

# Numbers of in-context examples placed in the prompt; the __main__ section
# issues one query per value (using texts[:n]).
n_samples_list = [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
# Interpolation types; only int_types[0] is used, to locate the interpolation
# dict file.
int_types = ['vanilla_linear']

# Dataset name; used as a sub-folder of final_results_analyzed/final_results.
dataset_name = 'reddit'
# Root of the data splits; test corpora are read from
# <data_folder_name>/test/<test_corpus_name>.
data_folder_name = "data_split_reddit"

def split_dash_samples(answer: str) -> list:
    """
    Split a model answer into one sample per non-blank line.

    A leading dash bullet ("-" or "- ") is removed when present. Lines without
    a bullet are kept whole instead of losing their first two characters.

    NOTE(review): preamble lines such as "Sure! Here are ..." are kept as
    samples, because the original code kept them too. Consider dropping lines
    that do not start with a dash if they pollute downstream results.

    Parameters:
        answer (str): Generated text with the prompt already removed.

    Returns:
        list: Non-empty, whitespace-stripped sample strings, in output order.
    """
    samples = []
    for line in answer.split('\n'):
        sample = line.strip()
        if sample.startswith('-'):
            # Strip only the bullet marker and the space(s) after it.
            sample = sample[1:].strip()
        if sample:
            samples.append(sample)
    return samples


def get_llama_response(prompt: str) -> list:
    """
    Generate a response from the Llama model and split it into samples.

    Relies on the module-level ``llama_pipeline`` and ``tokenizer`` globals,
    which are created in this script's ``__main__`` section.

    Parameters:
        prompt (str): The user's input/question for the model.

    Returns:
        list: Generated samples, one per non-blank output line, with any
        leading dash bullet removed (see ``split_dash_samples``). The list
        may be empty. The raw generated text is also printed.
    """
    sequences = llama_pipeline(
        prompt,
        do_sample=True,
        top_k=20,
        top_p=0.9,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=1.0,
        # NOTE: in transformers, max_length bounds prompt + generated tokens.
        max_length=4096,
    )
    generated_text = sequences[0]['generated_text']
    print(generated_text)
    # The output is assumed to echo the prompt first (pipeline default), so
    # the first len(prompt) characters are dropped before parsing.
    answer = generated_text[len(prompt):]
    return split_dash_samples(answer)

def build_examples_prompt(sample_corpus, n_examples):
    """
    Build the few-shot prompt that lists the example texts as a dash list.

    Parameters:
        sample_corpus (iterable of str): Example texts shown to the model.
        n_examples (int): Number of new samples requested.

    Returns:
        str: The prompt text, ending with a newline.
    """
    header = "Here are some samples of text:\n"
    examples = "".join("- " + text + '\n' for text in sample_corpus)
    instruction = (
        "Write " + str(n_examples) + " more text samples in the same style "
        "and of approximately equal length. Use a dash (-) before each "
        "generated sample.\n"
    )
    return header + examples + instruction


def query_examples_llama(sample_corpus, n_examples, max_attempts=None):
    """
    Ask Llama for ``n_examples`` new texts in the style of ``sample_corpus``.

    The model is re-queried while it returns no samples.

    Parameters:
        sample_corpus (iterable of str): Example texts placed in the prompt.
        n_examples (int): Number of new samples requested in the prompt.
        max_attempts (int or None): Maximum number of queries. ``None`` (the
            default) retries without limit, as the original did.

    Returns:
        list: The non-empty list returned by ``get_llama_response``.

    Raises:
        RuntimeError: If ``max_attempts`` queries all returned no samples.
    """
    user_query = build_examples_prompt(sample_corpus, n_examples)

    llama_response = get_llama_response(user_query)
    attempts = 1
    # Sampling can produce output with no usable lines; query again.
    while not llama_response:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                "Llama returned no samples after " + str(attempts) + " attempts"
            )
        llama_response = get_llama_response(user_query)
        attempts += 1
    print("RESPONSE:", llama_response)
    return llama_response



if __name__ == "__main__":
    # Usage: python <this script> <test_corpus_name>
    if len(sys.argv) < 2:
        sys.exit("usage: " + sys.argv[0] + " <test_corpus_name>")
    test_corpus_name = sys.argv[1]

    # `tokenizer` and `llama_pipeline` must remain module-level names:
    # get_llama_response() reads them as globals.
    model = "../llama-chat/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)

    llama_pipeline = pipeline(
        "text-generation",  # LLM task
        model=model,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    final_results_folder = os.path.join('final_results_analyzed', 'final_results', dataset_name)
    corpus_results_folder = os.path.join(final_results_folder, test_corpus_name)
    generations_folder = os.path.join(corpus_results_folder, other_generations_folder_name)

    print(test_corpus_name)

    # Pick the training file of this test split. sorted() makes the choice
    # deterministic when several file names contain 'train'.
    test_data_folder = os.path.join(data_folder_name, "test", test_corpus_name)
    train_data_files = sorted(name for name in os.listdir(test_data_folder) if 'train' in name)
    if not train_data_files:
        sys.exit("no file containing 'train' found in " + test_data_folder)
    with open(os.path.join(test_data_folder, train_data_files[0]), 'r') as corpus_file:
        # Samples are separated by blank lines. The last split element is
        # dropped (presumably empty because of a trailing separator - confirm).
        train_corpus = corpus_file.read().split('\n\n')[:-1]

    interpolation_file_path = os.path.join(corpus_results_folder, int_types[0], interpolation_dict_file_name)
    with open(interpolation_file_path, 'r') as file:
        interpolation_dict = json.load(file)

    os.makedirs(generations_folder, exist_ok=True)
    for current_interpolation, interpolation_info in interpolation_dict.items():
        print(current_interpolation)
        interpolation_folder = os.path.join(generations_folder, current_interpolation)
        os.makedirs(interpolation_folder, exist_ok=True)
        # 'samples_list' holds indices into train_corpus used as prompt examples.
        texts = [train_corpus[i] for i in interpolation_info['samples_list']]
        curr_output_file_path = os.path.join(interpolation_folder, output_file_name)
        # One list of generated samples per prompt size in n_samples_list.
        llama_generations = []
        for n_samples in n_samples_list:
            print(n_samples)
            out = query_examples_llama(texts[:n_samples], 10)
            # Strip first, then drop empties, so whitespace-only lines are removed.
            stripped = [text.strip() for text in out]
            llama_generations.append([text for text in stripped if text])
        with open(curr_output_file_path, 'wb') as file:
            pickle.dump(llama_generations, file)
